package com.hhf.ds.util;

import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.dom4j.Document;
import org.dom4j.DocumentException;
import org.dom4j.Element;
import org.dom4j.io.SAXReader;
import org.springframework.expression.EvaluationContext;
import org.springframework.expression.Expression;
import org.springframework.expression.ExpressionParser;
import org.springframework.expression.common.CompositeStringExpression;
import org.springframework.expression.common.TemplateParserContext;
import org.springframework.expression.spel.standard.SpelExpression;
import org.springframework.expression.spel.standard.SpelExpressionParser;
import org.springframework.expression.spel.support.StandardEvaluationContext;

import java.io.InputStream;
import java.lang.reflect.Field;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

/**
 * Loads one report definition from an XML document and renders its SQL template.
 *
 * <p>The report is the child element of the document root whose tag name equals the report
 * code; its {@code sql}, {@code colNames} and {@code rptName} child elements are read. When the
 * root element carries {@code show="true"}, the SQL and the column names are logged at INFO.
 *
 * <p>Template syntax handled by {@link #parseSQL(Object)}:
 * <ul>
 *   <li>{@code #{#param}} is a Spring SpEL template placeholder for the variable {@code param}.</li>
 *   <li>{@code { ... #{#param} ... }} is an optional clause: it is kept (without the outer braces)
 *       when {@code param} has a non-blank value and removed otherwise.</li>
 * </ul>
 * Variables are only bound for parameters that appear in at least one optional clause.
 *
 * @author haohaifeng
 * @date 2021/1/15 10:36
 */
@Slf4j
public class SqlParser {
    private final SqlXml sqlXml;

    /**
     * Parses the report identified by {@code rptCode} out of the given XML stream.
     *
     * @param rptCode     tag name of the report element below the document root
     * @param inputStream XML document to read
     * @throws IllegalArgumentException if the stream is not a parsable XML document
     */
    public SqlParser(String rptCode, InputStream inputStream) {
        this.sqlXml = parseXml(rptCode, inputStream);
    }

    /** Values read from the XML for a single report. */
    @Data
    private static class SqlXml {
        private String rptCode;
        private String rptName;
        private String colNames;
        private String show;
        private String sql;
    }

    /**
     * Reads the XML document and extracts the definition of one report.
     *
     * @param rptCode     tag name of the report element below the document root
     * @param inputStream XML document to read
     * @return the extracted report definition; {@code show} is {@code null} when the root element
     *         has no such attribute
     * @throws IllegalArgumentException if the stream is not a parsable XML document
     */
    public SqlXml parseXml(String rptCode, InputStream inputStream) {
        // NOTE(review): SAXReader is used with its default configuration. If the XML can ever come
        // from an untrusted source, DTDs / external entities should be disabled (XXE).
        SAXReader saxReader = new SAXReader();
        Document doc;
        try {
            doc = saxReader.read(inputStream);
        } catch (DocumentException e) {
            // Fail fast and keep the parser error as the cause; carrying on with a null document
            // would only surface as a context-free NullPointerException on the next line.
            throw new IllegalArgumentException("Failed to parse report XML for rptCode=" + rptCode, e);
        }
        Element root = doc.getRootElement();
        Element sqlGroup = root.element(rptCode);
        CoreAssert.notNull(sqlGroup, "rptCode不存在");
        SqlXml xml = new SqlXml();
        // attributeValue returns null for a missing attribute, so "show" stays optional.
        xml.setShow(root.attributeValue("show"));
        xml.setSql(sqlGroup.elementText("sql"));
        xml.setColNames(sqlGroup.elementText("colNames"));
        xml.setRptName(sqlGroup.elementText("rptName"));
        xml.setRptCode(rptCode);
        return xml;
    }

    /**
     * Returns the raw SQL template exactly as configured, without any parameter processing.
     *
     * @return the content of the report's {@code sql} element
     */
    public String getSQL() {
        String sql = sqlXml.getSql();
        if (isShowEnabled()) {
            log.info("{} : {}", sqlXml.getRptCode(), sql);
        }
        return sql;
    }

    /**
     * Renders the SQL template with the given parameters.
     *
     * @param obj parameter holder: either a {@link Map} keyed by parameter name, or any object
     *            whose declared fields are used as parameters (field name = parameter name)
     * @return the rendered SQL; if reading a field fails with {@link IllegalAccessException} the
     *         error is logged and the unrendered template is returned
     */
    public String parseSQL(Object obj) {
        String sql = sqlXml.getSql();
        try {
            sql = parseSql(sql, obj);
            if (isShowEnabled()) {
                log.info("{} : {}", sqlXml.getRptCode(), sql);
            }
        } catch (IllegalAccessException e) {
            log.error(e.getMessage(), e);
        }
        return sql;
    }

    /**
     * @return the content of the report's {@code rptName} element
     */
    public String getRptName() {
        return sqlXml.getRptName();
    }

    /**
     * Returns the header columns configured for the exported data.
     *
     * @return the {@code colNames} text with line breaks removed, split on {@code ,} or {@code ;}
     *         (surrounding whitespace dropped); {@code null} when no column names are configured
     */
    public String[] getTitleCols() {
        String colNames = sqlXml.getColNames();
        if (isShowEnabled()) {
            log.info(colNames);
        }
        if (StringUtils.isBlank(colNames)) {
            log.error("导出的表头没有配置");
            return null;
        }
        return colNames.replaceAll("\\r|\\n", "").trim().split("\\s*(,|;)\\s*");
    }

    /** True when the root element of the XML was declared with {@code show="true"}. */
    private boolean isShowEnabled() {
        return "true".equals(sqlXml.getShow());
    }

    /**
     * Resolves the optional clauses of the template against the parameters and evaluates the
     * remaining text as a SpEL template.
     *
     * @param sqlTemplate SQL template text
     * @param obj         a {@link Map} of parameters or an object whose declared fields are the parameters
     * @return the rendered SQL
     * @throws IllegalAccessException if a declared field of {@code obj} cannot be read
     * @throws NullPointerException   if {@code obj} is {@code null}
     */
    private static String parseSql(String sqlTemplate, Object obj) throws IllegalAccessException {
        if (obj == null) {
            throw new NullPointerException("obj must not be null");
        }
        ExpressionParser parser = new SpelExpressionParser();
        TemplateParserContext templateContext = new TemplateParserContext();
        // Values are handed to SpEL as context variables; the template refers to them as #name.
        EvaluationContext context = new StandardEvaluationContext();
        if (obj instanceof Map) {
            sqlTemplate = bindMapParameters(sqlTemplate, (Map<?, ?>) obj, parser, templateContext, context);
        } else {
            sqlTemplate = bindFieldParameters(sqlTemplate, obj, context);
        }
        // What is left is a plain SpEL template: literal SQL with #{...} placeholders.
        Expression expression = parser.parseExpression(sqlTemplate, templateContext);
        return expression.getValue(context, String.class);
    }

    /**
     * Map variant: every line that, once trimmed, starts with "{" and ends with "}" is an optional
     * clause controlled by the first placeholder found in it.
     */
    private static String bindMapParameters(String sqlTemplate, Map<?, ?> params, ExpressionParser parser,
                                            TemplateParserContext templateContext, EvaluationContext context) {
        List<String> optionalRows = Arrays.stream(sqlTemplate.split("\n"))
                .map(String::trim)
                .filter(row -> row.startsWith("{") && row.endsWith("}"))
                .collect(Collectors.toList());
        for (String row : optionalRows) {
            String key = firstPlaceholderKey(parser.parseExpression(row, templateContext));
            if (key == null) {
                // A braced line without any #{...} placeholder has nothing to bind; keep it as is.
                continue;
            }
            Object value = params.get(key);
            if (value != null && StringUtils.isNotBlank(value.toString())) {
                // Keep the clause but drop its outer braces.
                sqlTemplate = sqlTemplate.replace(row, StringUtils.substring(row, 1, row.length() - 1));
                context.setVariable(key, toSqlLiteral(value.toString()));
            } else {
                sqlTemplate = sqlTemplate.replace(row, "");
            }
        }
        return sqlTemplate;
    }

    /**
     * Returns the expression text of the first SpEL placeholder in a parsed row with every "#"
     * removed, or {@code null} when the row contains no placeholder.
     */
    private static String firstPlaceholderKey(Expression rowExpression) {
        if (!(rowExpression instanceof CompositeStringExpression)) {
            return null;
        }
        for (Expression part : ((CompositeStringExpression) rowExpression).getExpressions()) {
            if (part instanceof SpelExpression) {
                return part.getExpressionString().replace("#", "");
            }
        }
        return null;
    }

    /**
     * Object variant: every declared field controls the optional clauses that contain the
     * placeholder {@code #{#<fieldName>}}.
     */
    private static String bindFieldParameters(String sqlTemplate, Object obj, EvaluationContext context)
            throws IllegalAccessException {
        for (Field field : obj.getClass().getDeclaredFields()) {
            field.setAccessible(true);
            String name = field.getName();
            Object raw = field.get(obj);
            // A null field means "parameter absent"; it must not turn into the text "null".
            String value = raw == null ? null : raw.toString();
            Matcher clause = optionalClausePattern(name).matcher(sqlTemplate);
            if (StringUtils.isBlank(value)) {
                sqlTemplate = clause.replaceAll("");
            } else if (clause.find()) {
                // Group 1 is the clause body without the outer braces.
                sqlTemplate = clause.replaceAll("$1");
                context.setVariable(name, toSqlLiteral(value));
            }
        }
        return sqlTemplate;
    }

    /**
     * Builds the pattern for an optional clause of the given parameter:
     * "{", brace-free text, the placeholder {@code #{#name}}, brace-free text, "}", all on one line.
     * Anchoring on the exact placeholder keeps a field from matching clauses of other parameters
     * whose names merely contain its name, and from matching placeholders outside any clause.
     */
    private static Pattern optionalClausePattern(String name) {
        return Pattern.compile(
                "\\{([^{}\\r\\n]*#\\{\\s*#" + Pattern.quote(name) + "\\s*\\}[^{}\\r\\n]*)\\}");
    }

    /**
     * Wraps a value in single quotes, doubling embedded single quotes so the value cannot
     * terminate the literal.
     * NOTE(review): this is not a substitute for bind parameters; dialects that also treat
     * backslash as an escape character need additional handling.
     */
    private static String toSqlLiteral(String value) {
        return "'" + value.replace("'", "''") + "'";
    }

}
